from skimage import io
import os
import random
from PIL import Image

class Data(object):
    """Helpers for batch-loading images from a folder and saving arrays as images."""

    def read_image(self, filename, batch_size, shuffle=False, *,
                   drop_last=True, reader=None):
        """Yield the images of a directory in fixed-size batches.

        Args:
            filename: Path of the directory to read. Despite the name this is a
                folder; every regular file directly inside it is loaded.
            batch_size: Number of images per batch; must be >= 1.
            shuffle: If True, shuffle the file order (e.g. for a training set)
                before batching. Otherwise files are read in sorted name order.
            drop_last: If True (the default, matching the original behaviour),
                a trailing batch with fewer than ``batch_size`` images is
                discarded. If False, it is yielded as the last batch.
            reader: Callable that takes a file path and returns the loaded
                image. Defaults to ``io.imread`` (skimage).

        Yields:
            ``(batch_number, images)`` tuples, where ``batch_number`` is a
            1-based int and ``images`` is a new list for every batch, so
            callers may safely keep references to earlier batches.

        Raises:
            ValueError: If ``batch_size`` is less than 1.
        """
        if batch_size < 1:
            raise ValueError('batch_size must be >= 1, got %r' % (batch_size,))
        if reader is None:
            reader = io.imread

        # os.path.join instead of a hard-coded '\\' so this works on any OS.
        # os.listdir returns entries in arbitrary order; sorting makes the
        # unshuffled order (and a seeded shuffle) reproducible.
        paths = sorted(
            os.path.join(filename, entry) for entry in os.listdir(filename)
        )
        # Skip sub-directories and other non-files, which cannot be read as images.
        paths = [path for path in paths if os.path.isfile(path)]

        # Shuffle the training set.
        if shuffle:
            random.shuffle(paths)

        batch = []
        batch_number = 0
        for path in paths:
            batch.append(reader(path))
            if len(batch) == batch_size:
                batch_number += 1
                yield batch_number, batch
                # Start a fresh list instead of clearing the yielded one;
                # clearing would empty batches the caller is still holding.
                batch = []

        if batch and not drop_last:
            yield batch_number + 1, batch

    # Save an image.
    def save_image(self, img, folder_path, name, mode, ext='.jpg'):
        """Convert an array to a PIL image and write it to disk.

        Args:
            img: Array-like object with an ``astype`` method (e.g. a NumPy
                array). It is cast with ``astype('uint8')`` without any
                clipping or rescaling.
            folder_path: Directory to write the file into.
            name: File name without extension.
            mode: PIL mode the image is converted to before saving.
            ext: Extension (including the dot) appended to ``name``;
                defaults to '.jpg'.
        """
        img = Image.fromarray(img.astype('uint8')).convert(mode)
        img.save(os.path.join(folder_path, name + ext))